#' ---
#' title: "Greenwashing the Future? Computational Text Analysis of Environmental Reporting from the Fossil Fuel Industry"
#' subtitle: "02_performance_metrics.R"
#' author: "Robin Rauner"
#' date: "Note: Code compiled successfully on `r format(Sys.time(), '%d %B %Y')`"
#' ---

# load packages
library(tidyverse) # CRAN v2.0.0
library(rio)       # CRAN v1.2.2
library(mltest)    # CRAN v1.0.1                
library(tidycomm)  # CRAN v0.4.1
library(xtable)    # CRAN v1.8.4

sessionInfo()

# Compute classification and inter-coder agreement metrics for one pair of
# label columns.
#
# Arguments:
#   data     - data frame holding a `doc_id` column and both label columns
#   true_var - name (string) of the reference-label column
#   pred_var - name (string) of the column evaluated against `true_var`
#
# Returns a data frame with two columns, `Metrics` (metric label) and `Scores`
# (value): accuracy, averaged and per-class F1 / precision / recall, and
# Krippendorff's alpha between the two columns.
get_performance_metrics <- function(data, true_var, pred_var) {
    # fail early with a readable message instead of an obscure downstream error
    required_cols <- c("doc_id", true_var, pred_var)
    missing_cols <- setdiff(required_cols, names(data))
    if (length(missing_cols) > 0) {
        stop(
            "`data` is missing required column(s): ",
            paste(missing_cols, collapse = ", "),
            call. = FALSE
        )
    }

    # missing labels (e.g. documents without a match in an upstream left join)
    # would otherwise enter the metrics unnoticed
    n_incomplete <- sum(is.na(data[[true_var]]) | is.na(data[[pred_var]]))
    if (n_incomplete > 0) {
        warning(
            n_incomplete, " row(s) have a missing value in `", true_var,
            "` or `", pred_var, "`; metrics may be unreliable.",
            call. = FALSE
        )
    }

    # calculate standard metrics (predicted labels first, reference second)
    metrics <- ml_test(data[[pred_var]], data[[true_var]])

    # transform to long format: one row per document and coder; only the
    # required columns are kept so unrelated columns cannot clash with the
    # new `coder` / `coded_fls` names
    data_long <- data |>
        select(all_of(required_cols)) |>
        pivot_longer(
            cols = all_of(c(pred_var, true_var)),
            names_to = "coder",
            values_to = "coded_fls"
        )

    # calculate Krippendorff's alpha
    icr_metrics <- data_long |>
        test_icr(doc_id, coder, coded_fls)

    # return a data frame of selected metrics
    # NOTE(review): positions 1 and 2 of the per-class vectors are taken to be
    # classes 0 and 1 -- confirm against the ordering returned by ml_test
    data.frame(
        Metrics = c("Accuracy", "F1 (Average)", "F1 = 0", "F1 = 1", "Precision (Average)",
                    "Precision = 0", "Precision = 1", "Recall (Average)", "Recall = 0",
                    "Recall = 1", "Krippendorff's Alpha"),
        Scores = c(metrics$accuracy, mean(metrics$F1), metrics$F1[1],
                  metrics$F1[2], mean(metrics$precision), metrics$precision[1],
                  metrics$precision[2], mean(metrics$recall), metrics$recall[1],
                  metrics$recall[2], icr_metrics$Krippendorffs_Alpha)
    )
}

# Import and prepare data
# Validation sets (test and held-out); the company column is dropped
dat_validation_1 <- read.csv("validation_set.csv") |>
    select(!company)

dat_validation_2 <- read.csv("held_out_validation_set.csv") |>
    select(!company)


# Second coder's labels for the same two sets; `fls` is renamed so it can sit
# next to the other label columns after merging
dat_coder2_1 <- read.csv("validation_set_coder_2.csv") |>
    select(!company) |>
    rename(fls_coder2 = fls)

dat_coder2_2 <- read.csv("held_out_validation_set_coder_2.csv") |>
    select(!company) |>
    rename(fls_coder2 = fls)


# Hand-coded sets, one file per coder (suffixes _r and _k)
dat_r <- import("handcoded_set_R.csv") |>
    rename(ambiguous_r = ambiguous, fls_r = fls)

dat_k <- import("handcoded_set_K.csv") |>
    rename(ambiguous_k = ambiguous, fls_k = fls)

# LLM annotations; id and text columns are aligned with the other data sets
dat_llm <- import("llm_annotated_data.csv") |>
    rename(text = speech_text, fls_llm = fls, doc_id = speech_id)
colnames(dat_llm)

# Dictionary annotations
dat_dict <- import("dictionary_annotated_data.csv") |>
    select(!c(V1, company)) |>
    rename(fls_dict = fls)
colnames(dat_dict)

## Inter-coder reliability
# Merge coded data: every label source is attached by document id; `text` is
# dropped from the right-hand tables so it is only carried once
dat_test <- dat_validation_1 |>
    left_join(select(dat_coder2_1, !text), by = join_by(doc_id)) |>
    left_join(select(dat_llm, !text), by = join_by(doc_id)) |>
    left_join(select(dat_dict, !text), by = join_by(doc_id))
colnames(dat_test)

dat_heldout <- dat_validation_2 |>
    left_join(select(dat_coder2_2, !text), by = join_by(doc_id)) |>
    left_join(select(dat_llm, !text), by = join_by(doc_id)) |>
    left_join(select(dat_dict, !text), by = join_by(doc_id))
colnames(dat_heldout)

# Both hand-coders side by side
dat_icr <- left_join(dat_r, select(dat_k, !text), by = join_by(doc_id))

# Inspect
summary(dat_icr)


## Get test and held-out set metrics
# Each call scores one label source against the reference column `fls`; the
# generic `Scores` column is renamed after the source it describes.

# test set
dict_set1 <- rename(
    get_performance_metrics(dat_test, "fls", "fls_dict"),
    Dictionary = Scores
)
llm_set1 <- rename(
    get_performance_metrics(dat_test, "fls", "fls_llm"),
    LLM = Scores
)
coder2_set1 <- rename(
    get_performance_metrics(dat_test, "fls", "fls_coder2"),
    Human = Scores
)

# held-out set
dict_set2 <- rename(
    get_performance_metrics(dat_heldout, "fls", "fls_dict"),
    Dictionary = Scores
)
llm_set2 <- rename(
    get_performance_metrics(dat_heldout, "fls", "fls_llm"),
    LLM = Scores
)
coder2_set2 <- rename(
    get_performance_metrics(dat_heldout, "fls", "fls_coder2"),
    Human = Scores
)

# Combine metrics into a single table
# Each method appears twice (test set, then held-out set). The joins that bring
# in the second copy name both copies explicitly; without this the duplicated
# columns receive the default ".x" / ".y" suffixes, which would then be
# exported as the column headers of the LaTeX table.
set_suffix <- c(" (test)", " (held-out)")

tab_test_heldout <- dict_set1 |>
    left_join(llm_set1, by = "Metrics") |>
    left_join(coder2_set1, by = "Metrics") |>
    left_join(dict_set2, by = "Metrics", suffix = set_suffix) |>
    left_join(llm_set2, by = "Metrics", suffix = set_suffix) |>
    left_join(coder2_set2, by = "Metrics", suffix = set_suffix)

# Save LaTeX table
print(
    xtable(tab_test_heldout,
        caption = "Comparison of classifier performance",
        label = "tab:performance_test_heldout", digits = 2
    ),
    file = "performance_metrics_test_heldout.tex",
    include.rownames = FALSE,
    booktabs = TRUE,
    caption.placement = "top"
)

# Transform ICR data to long format: one row per document and coder, with the
# ambiguity flags removed beforehand
dat_long <- dat_icr |>
    select(!starts_with("ambiguous")) |>
    pivot_longer(
        cols = all_of(c("fls_r", "fls_k")),
        values_to = "coded_fls",
        names_to = "coder"
    )

head(dat_long)

# Agreement between the two hand-coders on the unrevised labels
icr_metrics <- test_icr(dat_long, doc_id, coder, coded_fls)

icr_metrics[["Krippendorffs_Alpha"]] # 0.67

# Inspect ambiguous cases
# flagged by at least one coder
with(dat_icr, sum(ambiguous_r == 1 | ambiguous_k == 1, na.rm = TRUE))

# flagged by both coders
with(dat_icr, sum(ambiguous_r == 1 & ambiguous_k == 1, na.rm = TRUE))

# Where fls_r != fls_k and ambiguous == 1, assign 0
dat_icr_clear <- dat_icr |>
    mutate(
        ambiguous = case_when(
            ambiguous_r == 1 | ambiguous_k == 1 ~ 1,
            .default = 0
        ),
        # rows to reset: flagged as ambiguous and coded differently; computed
        # once, before either label column is overwritten
        .reset_label = ambiguous == 1 & fls_r != fls_k,
        fls_r = ifelse(.reset_label, 0, fls_r),
        fls_k = ifelse(.reset_label, 0, fls_k),
        # helper column is dropped again so the result has the same columns
        .reset_label = NULL
    )

# Spot-check a random sample of the revised labels
slice_sample(select(dat_icr_clear, fls_r, fls_k, ambiguous), n = 25)

summary(dat_icr_clear)

# Transform to long format (revised labels, ambiguity flags removed)
dat_long_clear <- dat_icr_clear |>
    select(!starts_with("ambiguous")) |>
    pivot_longer(
        cols = all_of(c("fls_r", "fls_k")),
        values_to = "coded_fls",
        names_to = "coder"
    )

head(dat_long_clear)

# Calculate Krippendorff's alpha
icr_metrics_clear <- test_icr(dat_long_clear, doc_id, coder, coded_fls)

icr_metrics_clear[["Krippendorffs_Alpha"]] # 0.79


## Use revised hand-coding to calculate performance metrics
# Attach LLM and dictionary labels to the revised hand-coded set
dat_icr_all <- dat_icr_clear |>
    left_join(select(dat_llm, !text), by = join_by(doc_id)) |>
    left_join(select(dat_dict, !text), by = join_by(doc_id))
colnames(dat_icr_all)

# Score both classifiers against the revised `fls_r` labels
metrics_llm <- rename(
    get_performance_metrics(dat_icr_all, "fls_r", "fls_llm"),
    LLM = Scores
)
metrics_llm
metrics_dict <- rename(
    get_performance_metrics(dat_icr_all, "fls_r", "fls_dict"),
    Dictionary = Scores
)
metrics_dict

# One row per metric, one column per classifier
metrics_tab <- left_join(metrics_dict, metrics_llm, by = join_by(Metrics))
metrics_tab

# Inspect metrics
metrics <- ml_test(dat_icr_all[["fls_dict"]], dat_icr_all[["fls_r"]])
metrics[["accuracy"]]
metrics[["F1"]]
metrics[["precision"]]
metrics[["recall"]]

# Save LaTeX table (Table 1)
# The table object is built first and then written to the .tex file
metrics_tab |>
    xtable(
        caption = "Comparison of classifier performance (sample n = 250)",
        label = "tab:performance",
        digits = 2
    ) |>
    print(
        caption.placement = "top",
        booktabs = TRUE,
        include.rownames = FALSE,
        file = "performance_metrics_n250.tex"
    )